{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Spark SQL Agent\n",
    "\n",
    "This notebook shows how to use agents to interact with a Spark SQL. Similar to [SQL Database Agent](https://python.langchain.com/en/latest/modules/agents/toolkits/examples/sql_database.html), it is designed to address general inquiries about Spark SQL and facilitate error recovery.\n",
    "\n",
    "**NOTE: Note that, as this agent is in active development, all answers might not be correct. Additionally, it is not guaranteed that the agent won't perform DML statements on your Spark cluster given certain questions. Be careful running it on sensitive data!**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.agents import create_spark_sql_agent\n",
    "from langchain.agents.agent_toolkits import SparkSQLToolkit\n",
    "from langchain.chat_models import ChatOpenAI\n",
    "from langchain.utilities.spark_sql import SparkSQL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting default log level to \"WARN\".\n",
      "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
      "23/05/18 16:03:10 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+\n",
      "|PassengerId|Survived|Pclass|                Name|   Sex| Age|SibSp|Parch|          Ticket|   Fare|Cabin|Embarked|\n",
      "+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+\n",
      "|          1|       0|     3|Braund, Mr. Owen ...|  male|22.0|    1|    0|       A/5 21171|   7.25| null|       S|\n",
      "|          2|       1|     1|Cumings, Mrs. Joh...|female|38.0|    1|    0|        PC 17599|71.2833|  C85|       C|\n",
      "|          3|       1|     3|Heikkinen, Miss. ...|female|26.0|    0|    0|STON/O2. 3101282|  7.925| null|       S|\n",
      "|          4|       1|     1|Futrelle, Mrs. Ja...|female|35.0|    1|    0|          113803|   53.1| C123|       S|\n",
      "|          5|       0|     3|Allen, Mr. Willia...|  male|35.0|    0|    0|          373450|   8.05| null|       S|\n",
      "|          6|       0|     3|    Moran, Mr. James|  male|null|    0|    0|          330877| 8.4583| null|       Q|\n",
      "|          7|       0|     1|McCarthy, Mr. Tim...|  male|54.0|    0|    0|           17463|51.8625|  E46|       S|\n",
      "|          8|       0|     3|Palsson, Master. ...|  male| 2.0|    3|    1|          349909| 21.075| null|       S|\n",
      "|          9|       1|     3|Johnson, Mrs. Osc...|female|27.0|    0|    2|          347742|11.1333| null|       S|\n",
      "|         10|       1|     2|Nasser, Mrs. Nich...|female|14.0|    1|    0|          237736|30.0708| null|       C|\n",
      "|         11|       1|     3|Sandstrom, Miss. ...|female| 4.0|    1|    1|         PP 9549|   16.7|   G6|       S|\n",
      "|         12|       1|     1|Bonnell, Miss. El...|female|58.0|    0|    0|          113783|  26.55| C103|       S|\n",
      "|         13|       0|     3|Saundercock, Mr. ...|  male|20.0|    0|    0|       A/5. 2151|   8.05| null|       S|\n",
      "|         14|       0|     3|Andersson, Mr. An...|  male|39.0|    1|    5|          347082| 31.275| null|       S|\n",
      "|         15|       0|     3|Vestrom, Miss. Hu...|female|14.0|    0|    0|          350406| 7.8542| null|       S|\n",
      "|         16|       1|     2|Hewlett, Mrs. (Ma...|female|55.0|    0|    0|          248706|   16.0| null|       S|\n",
      "|         17|       0|     3|Rice, Master. Eugene|  male| 2.0|    4|    1|          382652| 29.125| null|       Q|\n",
      "|         18|       1|     2|Williams, Mr. Cha...|  male|null|    0|    0|          244373|   13.0| null|       S|\n",
      "|         19|       0|     3|Vander Planke, Mr...|female|31.0|    1|    0|          345763|   18.0| null|       S|\n",
      "|         20|       1|     3|Masselmani, Mrs. ...|female|null|    0|    0|            2649|  7.225| null|       C|\n",
      "+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from pyspark.sql import SparkSession\n",
    "\n",
    "spark = SparkSession.builder.getOrCreate()\n",
    "schema = \"langchain_example\"\n",
    "spark.sql(f\"CREATE DATABASE IF NOT EXISTS {schema}\")\n",
    "spark.sql(f\"USE {schema}\")\n",
    "csv_file_path = \"titanic.csv\"\n",
    "table = \"titanic\"\n",
    "spark.read.csv(csv_file_path, header=True, inferSchema=True).write.saveAsTable(table)\n",
    "spark.table(table).show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note, you can also connect to Spark via Spark connect. For example:\n",
    "# db = SparkSQL.from_uri(\"sc://localhost:15002\", schema=schema)\n",
    "spark_sql = SparkSQL(schema=schema)\n",
    "llm = ChatOpenAI(temperature=0)\n",
    "toolkit = SparkSQLToolkit(db=spark_sql, llm=llm)\n",
    "agent_executor = create_spark_sql_agent(llm=llm, toolkit=toolkit, verbose=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example: describing a table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
      "\u001b[32;1m\u001b[1;3mAction: list_tables_sql_db\n",
      "Action Input: \u001b[0m\n",
      "Observation: \u001b[38;5;200m\u001b[1;3mtitanic\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3mI found the titanic table. Now I need to get the schema and sample rows for the titanic table.\n",
      "Action: schema_sql_db\n",
      "Action Input: titanic\u001b[0m\n",
      "Observation: \u001b[33;1m\u001b[1;3mCREATE TABLE langchain_example.titanic (\n",
      "  PassengerId INT,\n",
      "  Survived INT,\n",
      "  Pclass INT,\n",
      "  Name STRING,\n",
      "  Sex STRING,\n",
      "  Age DOUBLE,\n",
      "  SibSp INT,\n",
      "  Parch INT,\n",
      "  Ticket STRING,\n",
      "  Fare DOUBLE,\n",
      "  Cabin STRING,\n",
      "  Embarked STRING)\n",
      ";\n",
      "\n",
      "/*\n",
      "3 rows from titanic table:\n",
      "PassengerId\tSurvived\tPclass\tName\tSex\tAge\tSibSp\tParch\tTicket\tFare\tCabin\tEmbarked\n",
      "1\t0\t3\tBraund, Mr. Owen Harris\tmale\t22.0\t1\t0\tA/5 21171\t7.25\tNone\tS\n",
      "2\t1\t1\tCumings, Mrs. John Bradley (Florence Briggs Thayer)\tfemale\t38.0\t1\t0\tPC 17599\t71.2833\tC85\tC\n",
      "3\t1\t3\tHeikkinen, Miss. Laina\tfemale\t26.0\t0\t0\tSTON/O2. 3101282\t7.925\tNone\tS\n",
      "*/\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3mI now know the schema and sample rows for the titanic table.\n",
      "Final Answer: The titanic table has the following columns: PassengerId (INT), Survived (INT), Pclass (INT), Name (STRING), Sex (STRING), Age (DOUBLE), SibSp (INT), Parch (INT), Ticket (STRING), Fare (DOUBLE), Cabin (STRING), and Embarked (STRING). Here are some sample rows from the table: \n",
      "\n",
      "1. PassengerId: 1, Survived: 0, Pclass: 3, Name: Braund, Mr. Owen Harris, Sex: male, Age: 22.0, SibSp: 1, Parch: 0, Ticket: A/5 21171, Fare: 7.25, Cabin: None, Embarked: S\n",
      "2. PassengerId: 2, Survived: 1, Pclass: 1, Name: Cumings, Mrs. John Bradley (Florence Briggs Thayer), Sex: female, Age: 38.0, SibSp: 1, Parch: 0, Ticket: PC 17599, Fare: 71.2833, Cabin: C85, Embarked: C\n",
      "3. PassengerId: 3, Survived: 1, Pclass: 3, Name: Heikkinen, Miss. Laina, Sex: female, Age: 26.0, SibSp: 0, Parch: 0, Ticket: STON/O2. 3101282, Fare: 7.925, Cabin: None, Embarked: S\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": "'The titanic table has the following columns: PassengerId (INT), Survived (INT), Pclass (INT), Name (STRING), Sex (STRING), Age (DOUBLE), SibSp (INT), Parch (INT), Ticket (STRING), Fare (DOUBLE), Cabin (STRING), and Embarked (STRING). Here are some sample rows from the table: \\n\\n1. PassengerId: 1, Survived: 0, Pclass: 3, Name: Braund, Mr. Owen Harris, Sex: male, Age: 22.0, SibSp: 1, Parch: 0, Ticket: A/5 21171, Fare: 7.25, Cabin: None, Embarked: S\\n2. PassengerId: 2, Survived: 1, Pclass: 1, Name: Cumings, Mrs. John Bradley (Florence Briggs Thayer), Sex: female, Age: 38.0, SibSp: 1, Parch: 0, Ticket: PC 17599, Fare: 71.2833, Cabin: C85, Embarked: C\\n3. PassengerId: 3, Survived: 1, Pclass: 3, Name: Heikkinen, Miss. Laina, Sex: female, Age: 26.0, SibSp: 0, Parch: 0, Ticket: STON/O2. 3101282, Fare: 7.925, Cabin: None, Embarked: S'"
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "agent_executor.run(\"Describe the titanic table\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example: running queries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
      "\u001b[32;1m\u001b[1;3mAction: list_tables_sql_db\n",
      "Action Input: \u001b[0m\n",
      "Observation: \u001b[38;5;200m\u001b[1;3mtitanic\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3mI should check the schema of the titanic table to see if there is an age column.\n",
      "Action: schema_sql_db\n",
      "Action Input: titanic\u001b[0m\n",
      "Observation: \u001b[33;1m\u001b[1;3mCREATE TABLE langchain_example.titanic (\n",
      "  PassengerId INT,\n",
      "  Survived INT,\n",
      "  Pclass INT,\n",
      "  Name STRING,\n",
      "  Sex STRING,\n",
      "  Age DOUBLE,\n",
      "  SibSp INT,\n",
      "  Parch INT,\n",
      "  Ticket STRING,\n",
      "  Fare DOUBLE,\n",
      "  Cabin STRING,\n",
      "  Embarked STRING)\n",
      ";\n",
      "\n",
      "/*\n",
      "3 rows from titanic table:\n",
      "PassengerId\tSurvived\tPclass\tName\tSex\tAge\tSibSp\tParch\tTicket\tFare\tCabin\tEmbarked\n",
      "1\t0\t3\tBraund, Mr. Owen Harris\tmale\t22.0\t1\t0\tA/5 21171\t7.25\tNone\tS\n",
      "2\t1\t1\tCumings, Mrs. John Bradley (Florence Briggs Thayer)\tfemale\t38.0\t1\t0\tPC 17599\t71.2833\tC85\tC\n",
      "3\t1\t3\tHeikkinen, Miss. Laina\tfemale\t26.0\t0\t0\tSTON/O2. 3101282\t7.925\tNone\tS\n",
      "*/\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3mThere is an Age column in the titanic table. I should write a query to calculate the average age and then find the square root of the result.\n",
      "Action: query_checker_sql_db\n",
      "Action Input: SELECT SQRT(AVG(Age)) as square_root_of_avg_age FROM titanic\u001b[0m\n",
      "Observation: \u001b[31;1m\u001b[1;3mThe original query seems to be correct. Here it is again:\n",
      "\n",
      "SELECT SQRT(AVG(Age)) as square_root_of_avg_age FROM titanic\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3mThe query is correct, so I can execute it to find the square root of the average age.\n",
      "Action: query_sql_db\n",
      "Action Input: SELECT SQRT(AVG(Age)) as square_root_of_avg_age FROM titanic\u001b[0m\n",
      "Observation: \u001b[36;1m\u001b[1;3m[('5.449689683556195',)]\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3mI now know the final answer\n",
      "Final Answer: The square root of the average age is approximately 5.45.\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": "'The square root of the average age is approximately 5.45.'"
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "agent_executor.run(\"whats the square root of the average age?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
      "\u001b[32;1m\u001b[1;3mAction: list_tables_sql_db\n",
      "Action Input: \u001b[0m\n",
      "Observation: \u001b[38;5;200m\u001b[1;3mtitanic\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3mI should check the schema of the titanic table to see what columns are available.\n",
      "Action: schema_sql_db\n",
      "Action Input: titanic\u001b[0m\n",
      "Observation: \u001b[33;1m\u001b[1;3mCREATE TABLE langchain_example.titanic (\n",
      "  PassengerId INT,\n",
      "  Survived INT,\n",
      "  Pclass INT,\n",
      "  Name STRING,\n",
      "  Sex STRING,\n",
      "  Age DOUBLE,\n",
      "  SibSp INT,\n",
      "  Parch INT,\n",
      "  Ticket STRING,\n",
      "  Fare DOUBLE,\n",
      "  Cabin STRING,\n",
      "  Embarked STRING)\n",
      ";\n",
      "\n",
      "/*\n",
      "3 rows from titanic table:\n",
      "PassengerId\tSurvived\tPclass\tName\tSex\tAge\tSibSp\tParch\tTicket\tFare\tCabin\tEmbarked\n",
      "1\t0\t3\tBraund, Mr. Owen Harris\tmale\t22.0\t1\t0\tA/5 21171\t7.25\tNone\tS\n",
      "2\t1\t1\tCumings, Mrs. John Bradley (Florence Briggs Thayer)\tfemale\t38.0\t1\t0\tPC 17599\t71.2833\tC85\tC\n",
      "3\t1\t3\tHeikkinen, Miss. Laina\tfemale\t26.0\t0\t0\tSTON/O2. 3101282\t7.925\tNone\tS\n",
      "*/\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3mI can use the titanic table to find the oldest survived passenger. I will query the Name and Age columns, filtering by Survived and ordering by Age in descending order.\n",
      "Action: query_checker_sql_db\n",
      "Action Input: SELECT Name, Age FROM titanic WHERE Survived = 1 ORDER BY Age DESC LIMIT 1\u001b[0m\n",
      "Observation: \u001b[31;1m\u001b[1;3mSELECT Name, Age FROM titanic WHERE Survived = 1 ORDER BY Age DESC LIMIT 1\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3mThe query is correct. Now I will execute it to find the oldest survived passenger.\n",
      "Action: query_sql_db\n",
      "Action Input: SELECT Name, Age FROM titanic WHERE Survived = 1 ORDER BY Age DESC LIMIT 1\u001b[0m\n",
      "Observation: \u001b[36;1m\u001b[1;3m[('Barkworth, Mr. Algernon Henry Wilson', '80.0')]\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3mI now know the final answer.\n",
      "Final Answer: The oldest survived passenger is Barkworth, Mr. Algernon Henry Wilson, who was 80 years old.\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": "'The oldest survived passenger is Barkworth, Mr. Algernon Henry Wilson, who was 80 years old.'"
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "agent_executor.run(\"What's the name of the oldest survived passenger?\")"
   ],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
